fix(light/linear): mark pytree_token as ephemeral via __getstate__/__setstate__#374

Merged: Jammy2211 merged 1 commit into main from feature/pytree-token-ephemeral on Apr 28, 2026
@Jammy2211
Collaborator

Summary

Mark LightProfileLinear.pytree_token as runtime-only via Python's standard __getstate__ / __setstate__ protocol, so that it is no longer persisted through PyAutoFit's database round-trip, which lossily converts the int to a float.

Why

The CI Smoke Tests jobs on autolens_workspace_test/main (both 3.12 and 3.13) have been failing with:

TypeError: __hash__ method should return an integer
  File "autogalaxy/abstract_fit.py", line 133, in linear_light_profile_intensity_dict
    linear_light_profile_intensity_dict[light_profile] = float(reconstruction[i])

Root cause: pytree_token (added in 29b16b47 to give a stable identity across jax.jit flatten/unflatten) is set in __init__ to the next value of a class-level itertools.count(), which is an int. When a fit is loaded back via FitImagingAgg, PyAutoFit's database deserializer routes every numeric attribute through the Value SQL row (sa.Float column in instance.py:96), so the saved int comes back as a float (0.0 / 1.0 observed in a live trace). Python requires __hash__ to return an int, so hashing the reloaded instance raises the TypeError above (observed on both the 3.12 and 3.13 jobs).
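The failure mode can be reproduced without PyAutoFit. The sketch below is only an illustration: the Token class is hypothetical, and the float-typed column is simulated with a plain float() call rather than a real database round-trip.

```python
import itertools


class Token:
    """Hypothetical stand-in for a class whose __hash__ returns a stored counter value."""

    _counter = itertools.count()

    def __init__(self):
        # The counter yields ints, so the token starts life as an int.
        self.pytree_token = next(Token._counter)

    def __hash__(self):
        return self.pytree_token

    def __eq__(self, other):
        return self is other


obj = Token()
hash_before = hash(obj)  # fine: the token is still an int

# Simulate reading the attribute back from a float-typed column.
obj.pytree_token = float(obj.pytree_token)

try:
    {obj: 1.0}  # using the object as a dict key calls __hash__
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

After the simulated round-trip, error_message holds the same text as the CI traceback.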

What this PR does

pytree_token is a process-local counter value, not meaningful state to persist: copying a number from one process's counter into another's preserves nothing. Implementing __getstate__ / __setstate__ is the canonical Python idiom for marking an attribute as ephemeral:

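def __getstate__(self):
    # Exclude the process-local token from the state handed to serializers.
    return {k: v for k, v in self.__dict__.items() if k != "pytree_token"}

def __setstate__(self, state):
    self.__dict__.update(state)
    # Assign a fresh token when the saved state did not carry one.
    if "pytree_token" not in state:
        self.pytree_token = next(LightProfileLinear._pytree_token_counter)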

PyAutoFit's Instance._from_object (at instance.py:73) and Object.__call__ (at model.py:189) already honour these methods, so no PyAutoFit changes are required. The JAX path uses register_model's attr_const, which reads vars(self) directly and is therefore unaffected by __getstate__.
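The difference between the two access paths can be seen with a small stand-in class. This is a sketch under assumptions: LinearProfileSketch and its centre attribute are hypothetical, only the two state methods mirror the PR, and copy.copy is used here because it goes through the same __reduce_ex__ protocol as pickle.

```python
import copy
import itertools


class LinearProfileSketch:
    """Hypothetical stand-in; only the two state methods mirror the PR."""

    _pytree_token_counter = itertools.count()

    def __init__(self, centre=(0.0, 0.0)):
        self.centre = centre
        self.pytree_token = next(LinearProfileSketch._pytree_token_counter)

    def __hash__(self):
        return self.pytree_token

    def __getstate__(self):
        return {k: v for k, v in self.__dict__.items() if k != "pytree_token"}

    def __setstate__(self, state):
        self.__dict__.update(state)
        if "pytree_token" not in state:
            self.pytree_token = next(LinearProfileSketch._pytree_token_counter)


original = LinearProfileSketch(centre=(0.1, 0.2))

# What a serializer that honours the protocol sees: no token.
state = original.__getstate__()

# What code reading the instance dict directly sees: the token is still there.
raw = vars(original)

# copy.copy calls __getstate__ and then __setstate__, so the copy gets a fresh int token.
restored = copy.copy(original)
```

The restored object keeps its ordinary attributes but receives a new integer token, so it hashes cleanly and does not collide with the original.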

Why not patch PyAutoFit's database serializer

A library-wide fix in PyAutoFit (e.g. an IntValue SQL table) would preserve every persisted int, but pytree_token is the only known consumer of int-typed persisted state whose __hash__ breaks. Galaxy.__hash__ already does int(self.id), and GeometryProfile.__hash__ returns id(self). The conceptually correct fix lives at the layer where the attribute is defined: this attribute should never have been persisted.
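For comparison, the coercion pattern attributed to Galaxy.__hash__ above repairs the value at hash time instead of excluding it from persisted state. A minimal sketch with a hypothetical CoercingId class (not Galaxy's actual code):

```python
class CoercingId:
    """Hypothetical class that tolerates a float-typed id after a round-trip."""

    def __init__(self, id_):
        self.id = id_

    def __hash__(self):
        # int() repairs a value such as 3.0 that came back from a float column.
        return int(self.id)

    def __eq__(self, other):
        return isinstance(other, CoercingId) and int(self.id) == int(other.id)


loaded = CoercingId(3.0)  # as if read back from a float column
lookup = {loaded: "ok"}
```

Coercion keeps hashing working, but the value is still persisted. For pytree_token that would carry one process's counter value into another process, which is why the PR drops the attribute from the saved state instead.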

Tests added

5 new unit tests in test_autogalaxy/profiles/light/linear/test_abstract.py:

  • test__pytree_token_is_int_and_unique
  • test__getstate__omits_pytree_token
  • test__setstate__assigns_fresh_pytree_token_when_missing
  • test__pickle_roundtrip_preserves_int_hash
  • test__setstate__preserves_pytree_token_when_present

All pure-numpy (no JAX imports), per the project's library-test policy.
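The PR does not show the test bodies. As a rough idea of what the two __setstate__ cases could check, here is a sketch against a hypothetical stand-in class; the function names follow the list above, while LinearProfileSketch, the intensity attribute and the assertions themselves are assumptions.

```python
import itertools


class LinearProfileSketch:
    """Hypothetical stand-in for LightProfileLinear, not the real class."""

    _pytree_token_counter = itertools.count()

    def __init__(self):
        self.pytree_token = next(LinearProfileSketch._pytree_token_counter)

    def __getstate__(self):
        return {k: v for k, v in self.__dict__.items() if k != "pytree_token"}

    def __setstate__(self, state):
        self.__dict__.update(state)
        if "pytree_token" not in state:
            self.pytree_token = next(LinearProfileSketch._pytree_token_counter)


def test__setstate__assigns_fresh_pytree_token_when_missing():
    # Build an empty instance without running __init__, as a deserializer would.
    profile = LinearProfileSketch.__new__(LinearProfileSketch)
    profile.__setstate__({"intensity": 1.0})  # "intensity" is a made-up attribute
    assert isinstance(profile.pytree_token, int)
    assert profile.intensity == 1.0


def test__setstate__preserves_pytree_token_when_present():
    profile = LinearProfileSketch.__new__(LinearProfileSketch)
    profile.__setstate__({"pytree_token": 42})
    assert profile.pytree_token == 42
```

Neither function imports JAX, in line with the policy stated above.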

Verified locally

  • pytest test_autogalaxy/ → 845 passed (840 existing + 5 new)
  • python autolens_workspace_test/scripts/database/scrape/general.py → exit 0, prints "FitImagingAgg Checked" with no traceback

Out of scope

  • The AssertionError failures in autolens_workspace_test/scripts/jax_likelihood_functions/... smoke tests — separate root cause, separate PR.

Test plan

  • PyAutoGalaxy CI Tests workflow goes green on this PR
  • After merge: re-trigger autolens_workspace_test Smoke Tests on main and confirm TypeError: __hash__ method should return an integer is gone

🤖 Generated with Claude Code

…setstate__

pytree_token is a process-local counter increment used as a stable
hash/eq identity for LightProfileLinear instances across jax.jit
flatten/unflatten. It is not meaningful state to persist — copying a
number from one process's counter into another's does not preserve
identity across processes.

PyAutoFit's database serializer routes every numeric attribute through
the Value SQL row (sa.Float column), so persisting pytree_token as an
int and reading it back yields a float. Python requires __hash__ to
return int, so any dict keyed on a LightProfileLinear (e.g.
the linear_light_profile_intensity_dict in autogalaxy/abstract_fit.py)
raised TypeError on the visualization path after a fit was loaded via
FitImagingAgg.

Add __getstate__ that omits pytree_token and __setstate__ that assigns
a fresh value if missing. PyAutoFit's Instance._from_object and
Object.__call__ already check for these methods and honour them, so no
PyAutoFit changes are required. The JAX-jit path uses register_model's
attr_const which reads vars(self) directly and is unaffected.

Fixes the smoke-test failure on autolens_workspace_test/main:
scripts/database/scrape/general.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 0dcea47 into main Apr 28, 2026
5 checks passed
@Jammy2211 Jammy2211 deleted the feature/pytree-token-ephemeral branch April 28, 2026 12:47